tjburch commented May 23, 2025

Closes #990

This PR implements support for variable fold weights in hyperparameter tuning. This is useful when folds have differing numbers of observations and each fold should contribute to hyperparameter selection in proportion to its size.

The implementation adds two main functions: add_fold_weights() to attach custom weights to rset objects, and calculate_fold_weights() to automatically compute weights proportional to fold sizes. Weights are stored as .fold_weights attributes and should flow through the existing tuning pipeline.
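
As a rough illustration of the size-proportional weighting described above, here is a minimal base R sketch. The helpers fold_size_weights() and attach_fold_weights() and the toy folds data frame are hypothetical stand-ins, not the PR's add_fold_weights() or calculate_fold_weights(), whose signatures are not shown in this thread. Only the .fold_weights attribute name is taken from the description.

```r
# Hypothetical helper: weights proportional to fold size, normalized to sum to 1.
fold_size_weights <- function(fold_sizes) {
  stopifnot(is.numeric(fold_sizes), all(fold_sizes > 0))
  fold_sizes / sum(fold_sizes)
}

# Hypothetical helper: store one weight per fold as an attribute on the object.
attach_fold_weights <- function(x, weights) {
  stopifnot(length(weights) == nrow(x))
  attr(x, ".fold_weights") <- weights
  x
}

# Toy stand-in for an rset: one row per fold, with the number of held-out rows.
folds <- data.frame(
  id = c("Fold1", "Fold2", "Fold3"),
  n_assess = c(10, 20, 30)
)

folds <- attach_fold_weights(folds, fold_size_weights(folds$n_assess))
attr(folds, ".fold_weights")  # 1/6, 1/3, 1/2
```

Storing the weights as an attribute keeps the columns of the object unchanged, which is one plausible reason the approach can stay non-breaking for code that ignores the attribute.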

The core changes are in estimate_tune_results(), which now detects weights and uses weighted statistics (weighted mean, weighted standard deviation, effective sample size) when aggregating metrics. The implementation should be backwards compatible and non-breaking.
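
To make the three weighted statistics concrete, here is a small base R sketch. The function names are hypothetical, and the formulas are assumptions rather than the PR's code: the standard deviation uses the reliability-weights correction (dividing by 1 - sum(w^2) after normalizing), and the effective sample size uses Kish's formula. The fold metrics are made-up numbers.

```r
# Assumed formulas; the PR's internal helpers may differ.
weighted_mean_sketch <- function(x, w) {
  w <- w / sum(w)            # normalize so the weights sum to 1
  sum(w * x)
}

weighted_sd_sketch <- function(x, w) {
  if (length(x) < 2) {
    return(NA_real_)         # spread is undefined for a single value
  }
  w <- w / sum(w)
  m <- sum(w * x)
  # Reliability-weights correction; reduces to sd() when weights are equal.
  sqrt(sum(w * (x - m)^2) / (1 - sum(w^2)))
}

# Kish's effective sample size; equals length(w) when weights are equal.
effective_n_sketch <- function(w) {
  sum(w)^2 / sum(w^2)
}

rmse <- c(2.0, 2.5, 3.0)     # made-up per-fold metric values
w <- c(0.1, 0.3, 0.6)        # higher weight on the last fold

weighted_mean_sketch(rmse, w)   # 2.75, versus an unweighted mean of 2.5
weighted_sd_sketch(rmse, w)
effective_n_sketch(w)
```

With equal weights these sketches reduce to mean(), sd(), and the number of folds, which is the behaviour a backwards-compatible aggregation would need.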

topepo (Member) commented Jun 4, 2025

Hey @tjburch. Thanks for the PR.

We're doing a pretty invasive update to this package that will take some time. We'll look at the PR after things are settled there, but it might be another 2-3 weeks.

Is this something time-sensitive for you?

tjburch (Author) commented Jun 4, 2025

Nope. Just had some bandwidth staying awake on paternity leave. Review at your leisure, let me know if I can assist otherwise.

topepo (Member) commented Jun 4, 2025

> Just had some bandwidth staying awake on paternity leave

I estimate that 5% of all my work has been done while waiting at the bus stop or for some sort of practice to end 😄

tjburch marked this pull request as draft July 23, 2025 02:01

tjburch (Author) commented Jul 23, 2025

Got around to deploying this into a local project and am running into some odd errors. Will circle back.

github-actions bot (Contributor) left a comment

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

air

[air] reported by reviewdog 🐶 in tune/R/utils.R at commit 27a5ab5, lines 393, 399, 408, 411, 414, 419, 427, 430, 436, 440, 466, and 481. Lines 399, 419, 436, and 440 each read #'.


[air] reported by reviewdog 🐶 on each of the following lines:

unweighted_results <- fit_resamples(simple_wflow, folds,
control = control_resamples(save_pred = FALSE))
weighted_results_equal <- fit_resamples(simple_wflow, weighted_folds_equal,
control = control_resamples(save_pred = FALSE))

expect_equal(unweighted_metrics$mean, weighted_metrics_equal$mean, tolerance = 1e-10)

unequal_weights <- c(0.1, 0.3, 0.6) # Higher weight on last fold

weighted_results_unequal <- fit_resamples(simple_wflow, weighted_folds_unequal,
control = control_resamples(save_pred = FALSE))

expect_false(all(abs(unweighted_metrics$mean - weighted_metrics_unequal$mean) < 1e-10))

expect_equal(sum(calculated_weights), 1) # Should sum to 1 now

skip_if_not_installed("parsnip")

tune_mod <- parsnip::linear_reg(penalty = tune()) %>%

weighted_tune_results <- tune_grid(tune_wflow, weighted_folds,
grid = simple_grid,
control = control_grid(save_pred = FALSE))

expect_s3_class(weighted_tune_results, "tune_results")

unweighted_tune_results <- tune_grid(tune_wflow, folds,
grid = simple_grid,
control = control_grid(save_pred = FALSE))

expect_false(all(abs(weighted_metrics$mean - unweighted_metrics$mean) < 1e-10))

if (rlang::is_installed(c("rsample", "parsnip", "yardstick", "workflows", "recipes", "kknn"))) {

c(1/6, 1/3, 1/2) # normalized to sum to 1

c(0.2, 0.3, 0.5) # already normalized to sum to 1

expect_true(is.na(tune:::.weighted_sd(c(1), c(1)))) # single value

expect_true(nrow(individual_metrics) >= 3) # At least one metric per fold

resamples = folds, # No weights

expected_weights <- weights / sum(weights) # normalized

splits1 <- rsample::make_splits(x = mtcars[1:20,], assessment = mtcars[21:32,])
splits2 <- rsample::make_splits(x = mtcars[1:15,], assessment = mtcars[16:32,])

tjburch (Author) commented Sep 11, 2025

Sorry for the noise here - didn't realize Air was part of ci/cd now.

fwiw I've been using my fork here for a while in a production process and it seems fine. I'm going to dust it off and get it ready to go asap

tjburch marked this pull request as ready for review September 12, 2025 12:21

tjburch (Author) commented Sep 12, 2025

Alright, I think this is ready for review again.

Successfully merging this pull request may close these issues.

Variable Fold Weights